################
# FUNCTIONS TO ESTIMATE WEIGHTS FOR IPTW WITH MULTICATEGORICAL TREATMENTS AND BINARY OUTCOMES
# AUTHOR: MICHELLE TORRES
# PREPARED FOR: ESTIMATING CONTROLLED DIRECT EFFECTS USING MARGINAL STRUCTURAL MODELS
################

# FUNCTION FOR INVERSE PROBABILITY OF TREATMENT WEIGHTS
# Estimate stabilized inverse-probability-of-treatment weights over `t`
# treatment periods.
#
# For every period i two treatment models are fitted: a "numerator" model
# (treatment at i on the treatments of periods 1..i-1 only) and a
# "denominator" model (same, plus covariates). The per-period ratio of the
# fitted probabilities of the treatment actually received is multiplied across
# periods to give one weight per observation.
#
# Arguments:
#   y     Outcome vector. Kept for interface compatibility; it is not used in
#         the weight computation.
#   X     List; X[[i]] holds the covariates for period i.
#   A     Matrix or data frame; column i is the treatment at period i.
#   t     Number of treatment periods.
#   type  Treatment model: "ordinal" (polr), "logit" (binomial glm) or
#         "multinom" (multinom from nnet). polr must already be available to
#         the caller (presumably via MASS -- it is not loaded here).
#   ...   Ignored.
#
# Returns a named list with the numerator/denominator models, their fitted
# matrices and observed-category probabilities for the LAST period, and the
# stabilized weights (element 6, "Stabilized weights").
#
# Raises an error for an unsupported `type`.
iptw <- function(y, X, A, t, type, ...){
  supported <- c("ordinal", "logit", "multinom")
  # Fail early: the original printed a message and then crashed on an
  # undefined object when building the result list.
  if (!isTRUE(type %in% supported)) {
    stop("Not supported! 'type' must be one of: ",
         paste(supported, collapse = ", "), call. = FALSE)
  }
  if (type == "multinom" && !require(nnet)) {
    stop("type = \"multinom\" needs the 'nnet' package", call. = FALSE)
  }

  # Column 1 of `mat` is the 1-based index of the observed treatment level,
  # the remaining columns are the fitted probabilities per level; return the
  # probability of the observed level for every row.
  observed.prob <- function(mat) {
    apply(mat, 1, function(x) unlist(x)[(unlist(x)[1] + 1)])
  }

  sw.probs <- NULL
  for (i in seq_len(t)) {
    temp.X <- X[[i]]
    temp.A <- as.factor(A[, i])
    if (i == 1) {
      temp.Mat.n <- data.frame(temp.A)
      temp.Mat.d <- data.frame(temp.A, temp.X)
    } else {
      temp.A_1 <- A[, 1:(i - 1)]
      temp.Mat.n <- data.frame(temp.A, temp.A_1)
      if (type != "logit") {
        # Carry forward the previous period's denominator predictors (minus
        # its treatment column). The "logit" variant never did this; that
        # difference is preserved here.
        temp.X <- cbind(temp.Mat.d, temp.X)[, -1]
      }
      temp.Mat.d <- data.frame(temp.A, temp.A_1, temp.X)
    }

    if (type == "ordinal") {
      model.n <- polr(as.factor(temp.A) ~ ., data = temp.Mat.n)
      model.d <- polr(as.factor(temp.A) ~ ., data = temp.Mat.d)
      fitted.n <- cbind(as.integer(as.numeric(temp.A)), model.n$fitted.values)
      fitted.d <- cbind(as.integer(as.numeric(temp.A)), model.d$fitted.values)
    } else if (type == "logit") {
      model.n <- glm(as.factor(temp.A) ~ ., data = temp.Mat.n,
                     family = binomial(link = "logit"))
      model.d <- glm(as.factor(temp.A) ~ ., data = temp.Mat.d,
                     family = binomial(link = "logit"))
      # glm models the probability of the second factor level, so the columns
      # are laid out as (level index, P(level 1), P(level 2)).
      fitted.n <- cbind(as.integer(as.numeric(temp.A)),
                        1 - model.n$fitted.values, model.n$fitted.values)
      fitted.d <- cbind(as.integer(as.numeric(temp.A)),
                        1 - model.d$fitted.values, model.d$fitted.values)
    } else {
      model.n <- multinom(as.factor(temp.A) ~ ., data = temp.Mat.n)
      model.d <- multinom(as.factor(temp.A) ~ ., data = temp.Mat.d)
      index <- sapply(temp.Mat.n[, "temp.A"],
                      function(x) which(levels(temp.Mat.n[, "temp.A"]) == x))
      # Rows dropped by the model for missingness must also be dropped from
      # the observed-level index so both line up.
      if (is.null(model.n$na.action)) {
        fitted.n <- cbind(index, model.n$fitted.values)
      } else {
        fitted.n <- cbind(index[-model.n$na.action], model.n$fitted.values)
      }
      if (is.null(model.d$na.action)) {
        fitted.d <- cbind(index, model.d$fitted.values)
      } else {
        fitted.d <- cbind(index[-model.d$na.action], model.d$fitted.values)
      }
    }

    probs.n <- observed.prob(fitted.n)
    probs.d <- observed.prob(fitted.d)
    sw <- probs.n / probs.d
    sw.probs <- cbind(sw.probs, sw)
  }
  # Product of the per-period ratios = stabilized weight per observation.
  final.weights <- apply(sw.probs, 1, function(x) prod(x))

  results <- list(model.n, model.d, fitted.n, fitted.d, cbind(probs.n,probs.d), final.weights)
  names(results)<- c("Model Numerator", "Model Denominator", "Fitted num.", "Fitted den.",
                     "Probs (num, dem)", "Stabilized weights")
  return(results)
}

### RETURN THE WEIGHTS FOR THE ANALYSIS
# Pull the outcome, treatment and per-period covariate columns out of `data`
# and return only the stabilized weights (6th element of the iptw() result).
#
#   data        Data frame holding all variables.
#   Y           Name of the outcome column.
#   A           Names of the treatment columns, one per period.
#   X.covs      List; element k holds the covariate column names for period k.
#   t           Number of treatment periods.
#   type.treat  Treatment model type, passed to iptw() as `type`.
#
# The defaults refer to objects that are not defined in this function; they
# are only looked up if the corresponding argument is omitted.
prepareIPTW <- function(data = ldata73_1[[1]],
                        Y = "polpart13_73",
                        A = atreat73_1,
                        X.covs =  xcovs73_1 ,
                        t = 2,
                        type.treat = "ordinal"
){
  # One covariate set per period, in period order
  covs.by.period <- llply(1:t, function(k) data[, X.covs[[k]]])
  treatments <- data[, A]
  outcome <- data[, Y]
  fit <- iptw(outcome, covs.by.period, treatments, t, type.treat)
  return(fit[[6]])
}


### ESTIMATE MSM, BOOTSTRAP STANDARD ERRORS AND RECALCULATE WEIGHTS IN EACH ITERATION OF THE BOOTSTRAPPING
### RETURN ALL COEFFICIENTS AND STANDARD ERRORS: MSM, UNDER CONTROL MODEL, SATURATED MODEL
# Fit the weighted MSM and two unweighted comparison models, and bootstrap the
# MSM standard errors, re-estimating the weights in every bootstrap sample.
#
#   data        Data frame with outcome, treatments and covariates.
#   Y, A        Outcome column name / treatment column names (see prepareIPTW).
#   X.covs      List of covariate column names per period.
#   t           Number of treatment periods.
#   formula     Outcome model as a character string.
#   type.treat  Treatment model type passed on to the weight estimation.
#   iter        Number of bootstrap replications.
#   seed        Seed set (globally, via set.seed) before resampling.
#
# Returns a matrix with columns "Coefficient" and "SD"; rows are the MSM
# coefficients (bootstrap SD), then the model that adds all covariates
# ("NaiveSat"), then the unweighted outcome model ("NaiveUnder"), the latter
# two with their model-based standard errors.
#
# NOTE(review): all glm() calls use the default (gaussian) family although the
# function name suggests a binary outcome -- confirm that a linear probability
# model is intended.
bootMSM_binout <- function(data, Y, A, X.covs,
                           t = 2, formula, type.treat ="ordinal",
                           iter = 1000,
                           seed = 12345){
  formula0 <- paste(formula, " + ", paste0(unlist(X.covs), collapse= " + "))

  # Unweighted comparison models: with all covariates, and outcome model only
  model0 <- glm(as.formula(formula0), data=data)
  coefs0 <- summary(model0)$coefficients
  rn0 <- rownames(coefs0)
  model0.1 <- glm(as.formula(formula), data=data)
  coefs0.1 <- summary(model0.1)$coefficients
  rn0.1 <- rownames(coefs0.1)

  # MSM on the full sample
  msm.wts.gral <- prepareIPTW(data, Y, A, X.covs, t, type.treat)
  msm_model <- glm(as.formula(formula), data=data, weights = msm.wts.gral)
  msm_table <- summary(msm_model)$coefficients
  msm_coefs <- msm_table[,1]
  coef.names <- rownames(msm_table)

  # One row per replication, one named column per full-sample coefficient.
  # summary()$coefficients drops aliased terms, so a bootstrap fit can report
  # fewer coefficients than the full-sample fit; filling by NAME keeps every
  # estimate in its own column and leaves NA where a term was not estimable
  # (the NAs are then ignored by sd(..., na.rm = TRUE) below).
  temp_mat_msm <- matrix(NA_real_, nrow = iter, ncol = length(coef.names),
                         dimnames = list(NULL, coef.names))
  set.seed(seed)
  n <- nrow(data)
  for (i in seq_len(iter)){
    sample.data <- data[sample(1:n, n, replace=TRUE),]
    temp.weights <- prepareIPTW(sample.data, Y, A, X.covs, t, type.treat)
    temp.data <- data.frame(cbind(sample.data, temp.weights))

    model1 <- glm(as.formula(formula), data=temp.data, weights = temp.weights)
    temp_mat_msm[i, ] <- coef(model1)[coef.names]
  }

  summary_msm <- cbind(msm_coefs, apply(temp_mat_msm,2,sd, na.rm=TRUE))
  summary_naive0 <- coefs0[, 1:2, drop = FALSE]
  summary_naive1 <- coefs0.1[, 1:2, drop = FALSE]
  fmat <- rbind(summary_msm, summary_naive0, summary_naive1)
  colnames(fmat) <- c("Coefficient", "SD")
  rownames(fmat) <- c(paste(rn0.1, "MSM", sep="."),
                      paste(rn0, "NaiveSat", sep="."),
                      paste(rn0.1, "NaiveUnder", sep="."))
  return(fmat)
}

### FUNCTIONS FOR THE SIMULATIONS
# Inverse logit function
# Inverse logit, 1 / (1 + exp(-x)), applied elementwise.
# Written in this form because exp(x) / (1 + exp(x)) evaluates to Inf / Inf =
# NaN once x exceeds roughly 709; here large x gives 1 and very negative x
# gives 0.
ilogit <- function(x){return(1/(1+exp(-x)))}

# Function to simulate ordered categorical outcomes
# Takes etas (plain beta*X) and cuts for the categories
# Category probabilities of an ordered-logit model.
#
#   etas  Numeric vector of linear predictors (one per observation).
#   cuts  Increasing cutpoints; K cutpoints define K + 1 categories.
#
# Returns a length(etas) x (length(cuts) + 1) matrix whose rows sum to 1:
# column j is P(category j) = F(cut_j - eta) - F(cut_{j-1} - eta), with F the
# logistic cdf, F(-Inf) = 0 and F(+Inf) = 1.
simOLOG <- function(etas, cuts){
  # Cumulative probabilities P(category <= j); 1 / (1 + exp(eta - cut)) is the
  # overflow-safe form of exp(cut - eta) / (1 + exp(cut - eta)).
  cum <- do.call(cbind, lapply(cuts, function(ct) 1 / (1 + exp(etas - ct))))
  # Differences of successive cumulative probabilities, padded with 0 and 1.
  probs <- cbind(cum, 1) - cbind(0, cum)
  colnames(probs) <- NULL
  return(probs)
}

# Function to calculate weights and return results from outcome model that will later be bootstrapped
# Estimate treatment weights on a (re)sampled data set and fit the weighted
# logistic outcome model; the (data, indices) signature lets it be used as a
# bootstrap statistic.
#
#   data       Data frame.
#   indices    Row indices selecting the sample (data[indices, ]).
#   formulas   Character matrix, one row per treatment period: column 1 holds
#              the numerator model formula, column 2 the denominator formula.
#   outmodel   Outcome model formula as a string.
#   method     "splines" (gam with an ocat family), "ologit" (polr) or "rf"
#              (CoreModel random forests). The "rf" branch is hard-coded to
#              columns t0, t1, x0 and x1 and does not read the formula text.
#   treat_ind  Column positions in `data` of the treatment variables, in the
#              same order as the rows of `formulas`.
#   num_cats   Number of treatment categories (used by "splines").
#   out        What to return: "std" (standard errors), "beta" (coefficients),
#              "wts" (mean, sd, min, max of the weights), "both" (list of the
#              weight summary and the coefficients) or "summ" (coefficients
#              and standard errors).
#   num_knots  Passed as `k` to gam() for "splines".
#   trim       If TRUE, weights are trimmed through svydesign()/trimWeights()
#              with upper = 3.
#   ...        Ignored.
#
# Raises an error for an unknown `method` or `out`.
bootWeights <- function(data,indices,formulas, outmodel, method, treat_ind, num_cats,out='std', num_knots,trim=FALSE,...){
  # Validate up front: previously an unknown value only surfaced as an
  # "object not found" error after (part of) the work had been done.
  if (!isTRUE(method %in% c('splines', 'ologit', 'rf'))) {
    stop("Unsupported method: must be 'splines', 'ologit' or 'rf'", call. = FALSE)
  }
  if (!isTRUE(out %in% c('std', 'beta', 'wts', 'both', 'summ'))) {
    stop("No output available: 'out' must be 'std', 'beta', 'wts', 'both' or 'summ'",
         call. = FALSE)
  }

  data2 <- data[indices,]
  if(method=='splines'){
    # Numerator
    mods_numerator <- lapply(formulas[,1],function(x) gam(as.formula(x),data=data2,family=ocat(R=num_cats),k=num_knots))
    fit_numerator <- lapply(mods_numerator, function(x) predict(x, type='response'))
    # Denominator
    mods_denominator <- lapply(formulas[,2],function(x) gam(as.formula(x),data=data2,family=ocat(R=num_cats),k=num_knots))
    fit_denominator <- lapply(mods_denominator, function(x) predict(x, type="response"))
  }
  else if(method=='ologit'){
    # Numerator
    mods_numerator <- lapply(formulas[,1],function(x) polr(as.formula(x),data=data2))
    fit_numerator <- lapply(mods_numerator, function(x) predict(x,type = 'probs'))
    # Denominator
    mods_denominator <- lapply(formulas[,2],function(x) polr(as.formula(x),data=data2))
    fit_denominator <- lapply(mods_denominator, function(x) predict(x,type = 'probs'))
  }
  else{
    library(dummies)
    # Random forest
    # NOTE(review): here the "num" models are the ones that include the
    # covariates and the "den" models are the treatment-history-only ones, so
    # the ratio below is P(A | covariates) / P(A | history) -- the reverse of
    # the numerator/denominator roles used in iptw(). Confirm this is intended.
    newdat <- data2
    newdat[,c('t0.0', 't0.1', 't0.2', 't1.0', 't1.1', 't1.2')] <- cbind(dummy(data2$t0), dummy(data2$t1))
    mod_num1 <- CoreModel(as.factor(t0) ~ x0, data=newdat, model="rf", selectionEstimator="MDL", minNodeWeightRF=5, rfNoTrees=100)
    mod_num2 <- CoreModel(as.factor(t1) ~ t0.0+t0.1+t0.2+x0+x1, data=newdat, model="rf", selectionEstimator="MDL", minNodeWeightRF=5, rfNoTrees=100)
    mod_den1 <- polr(as.factor(t0)~1, data=newdat)
    mod_den2 <- CoreModel(as.factor(t1) ~ t0.0+t0.1+t0.2, data=newdat, model="rf", selectionEstimator="MDL", minNodeWeightRF=5, rfNoTrees=100)
    fit_num1 <- predict(mod_num1, newdata=data.frame(newdat), type="probability")
    fit_num2 <- predict(mod_num2, newdata=data.frame(newdat), type="probability")
    fit_den1 <- predict(mod_den1, newdat, type='probs')
    fit_den2 <- predict(mod_den2, newdata=data.frame(newdat), type="probability")
    fit_denominator <- list(fit_den1, fit_den2)
    fit_numerator <- list(fit_num1, fit_num2)
  }

  # For each period, bind the observed treatment (first column) to the fitted
  # probability matrix and pick, per row, the probability in position
  # (treatment value + 1), i.e. the column matching the observed treatment.
  observed_prob <- function(fits) {
    lapply(seq_along(treat_ind), function(y) {
      apply(cbind(data2[,treat_ind[y]], fits[[y]]),1, function(x) x[(x[1]+1)])
    })
  }
  den <- observed_prob(fit_denominator)
  num <- observed_prob(fit_numerator)
  probs <- do.call(cbind, lapply(seq_len(nrow(formulas)), function(x) num[[x]]/den[[x]]))
  wts <- apply(probs, 1, prod)

  if(trim){
    mydsgn <- svydesign(ids=~1, data=data2, weights = wts)
    wts <- weights(trimWeights(mydsgn, upper=3))
  }

  # Outcome model
  out_m <- glm(as.formula(outmodel), data=data.frame(data2), weights = wts, family=binomial(link='logit'))
  out2 <- switch(out,
    std  = summary(out_m)$coefficients[,2],
    beta = summary(out_m)$coefficients[,1],
    wts  = c(mean(wts),sd(wts),min(wts),max(wts)),
    both = list(c(mean(wts),sd(wts),min(wts),max(wts)), summary(out_m)$coefficients[,1]),
    summ = summary(out_m)$coefficients[,1:2]
  )
  return(out2)
}

# Mini function to get probs
# Turn the coefficients of a logit model with two three-level treatments and
# their interaction into predicted probabilities for the nine treatment
# combinations y00 ... y22 (first digit: first treatment, second digit: second
# treatment), rounded to three decimals.
#
#   ests   Coefficient vector: position 1 is the intercept, 2-3 the first
#          treatment's effects, 4-5 the second treatment's effects and
#          6..9 (+ addon) the four interaction terms.
#   addon  Offset added to the interaction positions, for models that have
#          `addon` extra coefficients before the interactions.
getprpr <- function(ests, addon=0){
  base <- ests[1]
  # Linear predictor (log-odds) of each treatment combination
  logodds <- c(
    y00 = base,
    y01 = (base + ests[4]),
    y02 = (base + ests[5]),
    y10 = (base + ests[2]),
    y11 = (base + ests[2] + ests[4] + ests[6 + addon]),
    y12 = (base + ests[2] + ests[5] + ests[8 + addon]),
    y20 = (base + ests[3]),
    y21 = (base + ests[3] + ests[4] + ests[7 + addon]),
    y22 = (base + ests[3] + ests[5] + ests[9 + addon])
  )
  # Inverse logit, vectorised over the nine cells
  odds <- exp(logodds)
  round(odds / (1 + odds), 3)
}

# Pairwise differences between the nine probabilities returned by getprpr()
# (order y00, y01, y02, y10, ..., y22): for each level of the second
# treatment, the contrasts 1-vs-0, 2-vs-0 and 2-vs-1 of the first treatment.
getcdes <- function(probs){
  # Difference between the entries at positions `hi` and `lo`
  gap <- function(hi, lo) probs[hi] - probs[lo]
  c(
    # second treatment = 0
    y10_y00 = gap(4, 1), y20_y00 = gap(7, 1), y20_y10 = gap(7, 4),
    # second treatment = 1
    y11_y01 = gap(5, 2), y21_y01 = gap(8, 2), y21_y11 = gap(8, 5),
    # second treatment = 2
    y12_y02 = gap(6, 3), y22_y02 = gap(9, 3), y22_y12 = gap(9, 6)
  )
}

### POOL ESTIMATES FROM MULTIPLE IMPUTATION
library(plyr)
# Pool coefficient tables from m multiply-imputed data sets.
#
#   list_tables  List of m tables (matrix or data frame) with estimates in
#                column 1 and standard errors in column 2; all tables must
#                have the same rows.
#   index        Rows to pool (default: all rows of the first table).
#   alpha        Significance level for the confidence interval.
#
# Returns a matrix (rounded to 4 decimals) with one row per pooled coefficient
# and columns Estimate, Std. Error, t value, Pr(>|t|), Lo CI, Hi CI. The
# pooled variance is mean(se^2) + ((m + 1) / m) * var(estimates).
poolEstimates <- function(list_tables, index=NULL, alpha=0.05){
  m <- length(list_tables)
  rowind <- if (is.null(index)) seq_len(nrow(list_tables[[1]])) else index

  # Always build m x k matrices (imputations in rows). Stacking with rbind
  # keeps the matrix shape even when a single coefficient is pooled, where a
  # simplifying apply would collapse to a plain vector and break the
  # column-wise summaries below.
  coefs.imp <- do.call(rbind, lapply(list_tables, function(x) x[rowind, 1]))
  se.imp <- do.call(rbind, lapply(list_tables, function(x) x[rowind, 2]))

  ## Calculate estimates
  mean.coef <- colMeans(coefs.imp)
  between.var <- apply(coefs.imp, 2, var)
  within.var <- colMeans(se.imp^2)
  impute.se.vec <- sqrt(within.var + ((m+1)/m)*between.var)
  # NOTE(review): the usual Rubin degrees of freedom are
  # (m - 1) * (1 + m / (m + 1) * within / between)^2; the factor here is
  # 1 / (m + 1). Left unchanged -- confirm against the intended reference.
  impute.df <- (m-1)*((1 + (1/(m+1)) * within.var/between.var)^2)

  ## Build the stats table
  t.stat <- mean.coef/impute.se.vec
  half.width <- qt(1-alpha/2,impute.df)*impute.se.vec
  out.table <- round(cbind(mean.coef,impute.se.vec,
                           t.stat,
                           2 * (1 - pt(abs(t.stat), impute.df)),
                           mean.coef-half.width,
                           mean.coef+half.width),4)

  dimnames(out.table) <- list(rownames(list_tables[[1]])[rowind],
                              c("Estimate","Std. Error","t value","Pr(>|t|)", "Lo CI", "Hi CI"))

  return(out.table)
}

# Function to have transparent colors
# Make colours transparent.
#
#   color    Any colour specification accepted by col2rgb(); may be a vector.
#   percent  Transparency in percent (0 = opaque, 100 = fully transparent).
#
# Returns one "#RRGGBBAA" string per input colour. (Previously only the first
# colour of a vector was converted.)
t_col <- function(color, percent = 50) {
  rgb.val <- col2rgb(color)
  # Rows of rgb.val are red, green, blue; columns are the input colours.
  t.col <- rgb(unname(rgb.val[1, ]), unname(rgb.val[2, ]), unname(rgb.val[3, ]),
               maxColorValue = 255,
               alpha = (100-percent)*255/100)
  return(t.col)
}


## Function to simulate data
# Simulate one data set with a binary baseline covariate x0, a three-level
# treatment at two periods (t0, t1), an intermediate binary covariate x1 that
# depends on t0, and a binary outcome y; then compare the true controlled
# direct effects with those from a naive saturated logit and from the weighted
# (ologit-weights) MSM.
#
#   n           Sample size.
#   beta.x0_t0  Effect of x0 on the first treatment.
#   beta.x0_t1  Effect of x0 on the second treatment.
#
# Returns list(true CDEs, CDEs from the naive saturated model, CDEs from the
# weighted model), each as produced by getcdes().
simdatnew2 <- function(n=1000,
                      beta.x0_t0=3,
                      beta.x0_t1=beta.x0_t0/2){

  # One Bernoulli draw per observation with success probability ilogit(eta)
  draw_binary <- function(eta) {
    sapply(ilogit(eta), function(p) sample(0:1, 1, prob = c(1 - p, p)))
  }
  # One draw from {0, 1, 2} per observation under an ordered-logit model
  draw_ordinal <- function(eta, cuts) {
    apply(simOLOG(eta, cuts), 1, function(p) sample(0:2, 1, prob = p))
  }
  trt_levels <- 0:2

  #### SIMULATE DATA
  # (The order of the random draws below is the same as in earlier versions of
  # this function, so a given seed yields the same simulated data.)
  # First stage: X_0
  x0 <- sample(0:1, n, replace=TRUE, prob=c(0.6,0.4))

  # Treatment: T_0
  u_t0 <- rnorm(n,0,0.05)
  eta_t0 <- -2.5 + x0*beta.x0_t0 + u_t0
  cuts_t0 <- c(-1.25, .45)
  t0 <- draw_ordinal(eta_t0, cuts_t0)

  # Second stage: X_1, potential values under t0 = 0, 1, 2
  u_x1 <- rnorm(n,0,0.035)
  x1_pot <- lapply(trt_levels, function(a) {
    draw_binary(-2.5 + x0*0.5 + a*1.1 - u_x1)
  })
  x1 <- ifelse(t0==0, x1_pot[[1]], ifelse(t0==1, x1_pot[[2]], x1_pot[[3]]))

  # Treatment: T_1, potential values under t0 = 0, 1, 2
  u_t1 <- rnorm(n,0,0.05)
  cuts_t1 <- c(-1.05, .65)
  t1_pot <- lapply(trt_levels, function(a) {
    draw_ordinal(-3.5 + x0*beta.x0_t1 + x1_pot[[a+1]]*0.6 + a*1 + u_t1, cuts_t1)
  })
  t1 <- ifelse(t0==0, t1_pot[[1]], ifelse(t0==1, t1_pot[[2]], t1_pot[[3]]))

  # Potential outcomes for every (T0 = a, T1 = b), drawn in the order
  # y00, y01, y02, y10, ..., y22
  u_y <- rnorm(n,0,0.04)
  y_pot <- vector("list", 9)
  for (a in trt_levels) {
    for (b in trt_levels) {
      y_pot[[3*a + b + 1]] <- draw_binary(
        -3 + x0*0.2 + a*1.5 + x1_pot[[a+1]]*0.4 + 0.2*b + u_y
      )
    }
  }
  y_mat <- do.call(cbind, y_pot)

  # And Y based on observed assignment: column 3*t0 + t1 + 1 of y_mat
  y <- y_mat[cbind(seq_len(n), 3*t0 + t1 + 1)]

  # Some quick clean-up of the data to avoid breaking the functions
  simdat <- data.frame(y, t0=as.factor(t0), t1=as.factor(t1), x0, x1)
  simdat$t0.rec <- as.numeric(as.factor(simdat$t0))
  simdat$t1.rec <- as.numeric(as.factor(simdat$t1))

  ########## ESTIMATE REAL PROBABILITIES AND NAIVE MODEL
  # Share of ones in each potential outcome. A column mean is used instead of
  # (table(y)/n)[2], which returns NA whenever a potential outcome is constant
  # (only one cell in the table). Names are kept as "1" as before.
  real_probs <- setNames(colMeans(y_mat == 1), rep("1", 9))
  # getprpr() applies an inverse logit to these coefficients, so the model has
  # to be a logit model (the glm() default family is gaussian).
  ests_bw_saturate <- summary(glm(y ~ as.factor(t0) + as.factor(t1) + as.factor(t0):as.factor(t1) + x0 + x1,
                                  data=simdat, family=binomial(link="logit")))$coefficients[,1]
  ########## ESTIMATE BOOTWEIGHTS OBJECT WITH WEIGHTS AND COEFFICIENTS OUTPUT
  # List ologit
  # NOTE(review): with byrow = TRUE the first column (used by bootWeights as
  # the numerator) holds the covariate-adjusted treatment models and the
  # second column (denominator) the treatment-history-only models, so the
  # weights are P(A | covariates) / P(A | history). Stabilized IPT weights are
  # normally the reverse ratio -- confirm the column order.
  temp_ologit <- bootWeights(simdat,matrix(c('as.factor(t0) ~ x0','as.factor(t0) ~ 1',
                                             'as.factor(t1) ~ x1+as.factor(t0)+x0', 'as.factor(t1) ~ as.factor(t0)'), ncol=2, byrow = TRUE),
                             'y ~ as.factor(t0) + as.factor(t1) + as.factor(t0):as.factor(t1)', 'ologit',
                             treat_ind = c(2,3), num_knots=3, num_cats=3, indices=1:n, out='both')
  # Beta ologit
  ests_bw_ologit <- temp_ologit[[2]]

  probs_bw_ologit <- getprpr(as.numeric(ests_bw_ologit))
  cdes_bw_ologit <- getcdes(probs_bw_ologit)

  # addon = 2: x0 and x1 sit between the main effects and the interactions
  probs_bw_sat <- getprpr(as.numeric(ests_bw_saturate),2)
  cdes_bw_saturate <- getcdes(probs_bw_sat)

  cdes_true <- getcdes(real_probs)

  return(list(cdes_true, cdes_bw_saturate, cdes_bw_ologit))
}


# Simulate one data set as in simdatnew2(), with an additional intermediate
# binary variable w1 that depends on t0 and affects both t1 and y; then
# compare the true controlled direct effects with those from a naive saturated
# logit and from the weighted (ologit-weights) MSM.
#
#   n           Sample size.
#   beta.x0_t0  Effect of x0 on the first treatment.
#   beta.x0_t1  Effect of x0 on the second treatment.
#   beta.w1_t1  Effect of w1 on the second treatment.
#   beta.w1_y   Effect of w1 on the outcome.
#
# Returns list(true CDEs, CDEs from the naive saturated model, CDEs from the
# weighted model), each as produced by getcdes().
simdatnew <- function(n=1000,
                      beta.x0_t0=1,
                      beta.x0_t1=beta.x0_t0/2,
                      beta.w1_t1=0.6,
                      beta.w1_y=beta.w1_t1/2){

  # One Bernoulli draw per observation with success probability ilogit(eta)
  draw_binary <- function(eta) {
    sapply(ilogit(eta), function(p) sample(0:1, 1, prob = c(1 - p, p)))
  }
  # One draw from {0, 1, 2} per observation under an ordered-logit model
  draw_ordinal <- function(eta, cuts) {
    apply(simOLOG(eta, cuts), 1, function(p) sample(0:2, 1, prob = p))
  }
  trt_levels <- 0:2

  #### SIMULATE DATA
  # (The order of the random draws below is the same as in earlier versions of
  # this function, so a given seed yields the same simulated data.)
  # First stage: X_0
  x0 <- sample(0:1, n, replace=TRUE, prob=c(0.6,0.4))

  # Treatment: T_0
  u_t0 <- rnorm(n,0,0.05)
  eta_t0 <- -2.5 + x0*beta.x0_t0 + u_t0
  cuts_t0 <- c(-1.25, .45)
  t0 <- draw_ordinal(eta_t0, cuts_t0)

  # Second stage: X_1, potential values under t0 = 0, 1, 2
  u_x1 <- rnorm(n,0,0.035)
  x1_pot <- lapply(trt_levels, function(a) {
    draw_binary(-2.5 + x0*0.5 + a*1.1 - u_x1)
  })
  x1 <- ifelse(t0==0, x1_pot[[1]], ifelse(t0==1, x1_pot[[2]], x1_pot[[3]]))

  # Second stage: W_1, potential values under t0 = 0, 1, 2
  u_w1 <- rnorm(n,0,0.035)
  w1_pot <- lapply(trt_levels, function(a) {
    draw_binary(-2.5 + x0*0.5 + a*-1.5 - u_w1)
  })
  # NOTE(review): the observed w1 is computed but never added to `simdat`, so
  # none of the models below can adjust for it -- presumably it is meant to be
  # unobserved; confirm.
  w1 <- ifelse(t0==0, w1_pot[[1]], ifelse(t0==1, w1_pot[[2]], w1_pot[[3]]))

  # Treatment: T_1, potential values under t0 = 0, 1, 2
  u_t1 <- rnorm(n,0,0.5)
  cuts_t1 <- c(-1.05, .65)
  t1_pot <- lapply(trt_levels, function(a) {
    draw_ordinal(-3.5 + x0*beta.x0_t1 + x1_pot[[a+1]]*0.6 + w1_pot[[a+1]]*beta.w1_t1 + a*1 + u_t1,
                 cuts_t1)
  })
  t1 <- ifelse(t0==0, t1_pot[[1]], ifelse(t0==1, t1_pot[[2]], t1_pot[[3]]))

  # Potential outcomes for every (T0 = a, T1 = b), drawn in the order
  # y00, y01, y02, y10, ..., y22
  u_y <- rnorm(n,0,0.4)
  y_pot <- vector("list", 9)
  for (a in trt_levels) {
    for (b in trt_levels) {
      y_pot[[3*a + b + 1]] <- draw_binary(
        -3 + x0*0.2 + a*1.5 + x1_pot[[a+1]]*0.4 + w1_pot[[a+1]]*beta.w1_y + 0.2*b + u_y
      )
    }
  }
  y_mat <- do.call(cbind, y_pot)

  # And Y based on observed assignment: column 3*t0 + t1 + 1 of y_mat
  y <- y_mat[cbind(seq_len(n), 3*t0 + t1 + 1)]

  # Some quick clean-up of the data to avoid breaking the functions
  simdat <- data.frame(y, t0=as.factor(t0), t1=as.factor(t1), x0, x1)
  simdat$t0.rec <- as.numeric(as.factor(simdat$t0))
  simdat$t1.rec <- as.numeric(as.factor(simdat$t1))

  ########## ESTIMATE REAL PROBABILITIES AND NAIVE MODEL
  # Share of ones in each potential outcome, relative to the actual sample
  # size (the divisor used to be a fixed 1000, which is wrong for n != 1000).
  # A column mean also avoids the NA that (table(y)/n)[2] gives when a
  # potential outcome is constant. Names are kept as "1" as before.
  real_probs <- setNames(colMeans(y_mat == 1), rep("1", 9))
  # getprpr() applies an inverse logit to these coefficients, so the model has
  # to be a logit model (the glm() default family is gaussian).
  ests_bw_saturate <- summary(glm(y ~ as.factor(t0) + as.factor(t1) + as.factor(t0):as.factor(t1) + x0 + x1,
                                  data=simdat, family=binomial(link="logit")))$coefficients[,1]
  ########## ESTIMATE BOOTWEIGHTS OBJECT WITH WEIGHTS AND COEFFICIENTS OUTPUT
  # List ologit
  # NOTE(review): with byrow = TRUE the first column (used by bootWeights as
  # the numerator) holds the covariate-adjusted treatment models and the
  # second column (denominator) the treatment-history-only models, so the
  # weights are P(A | covariates) / P(A | history). Stabilized IPT weights are
  # normally the reverse ratio -- confirm the column order.
  temp_ologit <- bootWeights(simdat,matrix(c('as.factor(t0) ~ x0','as.factor(t0) ~ 1',
                                             'as.factor(t1) ~ x1+as.factor(t0)+x0', 'as.factor(t1) ~ as.factor(t0)'), ncol=2, byrow = TRUE),
                             'y ~ as.factor(t0) + as.factor(t1) + as.factor(t0):as.factor(t1)', 'ologit',
                             treat_ind = c(2,3), num_knots=3, num_cats=3, indices=1:n, out='both')
  # Beta ologit
  ests_bw_ologit <- temp_ologit[[2]]

  probs_bw_ologit <- getprpr(as.numeric(ests_bw_ologit))
  cdes_bw_ologit <- getcdes(probs_bw_ologit)

  # addon = 2: x0 and x1 sit between the main effects and the interactions
  probs_bw_sat <- getprpr(as.numeric(ests_bw_saturate),2)
  cdes_bw_saturate <- getcdes(probs_bw_sat)

  cdes_true <- getcdes(real_probs)

  return(list(cdes_true, cdes_bw_saturate, cdes_bw_ologit))
}